import os

import torch
from PIL import Image
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
from bdtime import tt
from bdtime.with_timer import with_timer
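# Cache layout: `../models` holds the raw ModelScope snapshot, while
# `../models/local` holds standalone copies written with save_pretrained()
# so later runs can reload without resolving the snapshot cache.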


cache_dir = os.path.join("..", "models")
os.makedirs(cache_dir, exist_ok=True)

local_cache_dir = os.path.join(cache_dir, 'local')
os.makedirs(local_cache_dir, exist_ok=True)

# Fall back to CPU when CUDA is unavailable (generation will be very slow on CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"

# model_id = "ZhipuAI/glm-4v-9b"
# tokenizer = AutoTokenizer.from_pretrained(
#     model_id,
#     trust_remote_code=True,
#     cache_dir=cache_dir,
#     local_files_only=True,
# )

model_name = "ZhipuAI/glm-4v-9b"
model_dir = snapshot_download(model_name, cache_dir=cache_dir, local_files_only=True)

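# GLM-4V ships custom modeling/tokenization code, so trust_remote_code
# must be True for both from_pretrained calls below.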
trust_remote_code = True
# trust_remote_code = False


local_model_dir = os.path.join(local_cache_dir, model_name)
os.makedirs(local_model_dir, exist_ok=True)

local_tokenizer_path = os.path.join(local_model_dir, 'tokenizer')
local_model_path = os.path.join(local_model_dir, 'model')

# Record whether the standalone local copy exists *before* redirecting the
# paths to the snapshot dir, so the save step below still runs on the first pass.
loaded_from_local = os.path.exists(local_model_path)
print(f'------- os.path.exists(local_model_path): {loaded_from_local}')
if not loaded_from_local:
    local_tokenizer_path = model_dir
    local_model_path = model_dir
else:
    print(f'--- Loading model from local copy! local_model_dir: {local_model_dir}')

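# bfloat16 halves memory versus fp32; low_cpu_mem_usage streams weights in
# instead of materializing a full copy on the host first.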
tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path, trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=trust_remote_code
).to(device).eval()

if not loaded_from_local:
    # First run: persist standalone copies so subsequent runs load directly
    # from `local_model_dir`.
    local_tokenizer_path = os.path.join(local_model_dir, 'tokenizer')
    local_model_path = os.path.join(local_model_dir, 'model')
    print(f'------ Saving `tokenizer` and `model` to: {local_model_dir}')
    tokenizer.save_pretrained(local_tokenizer_path)
    model.save_pretrained(local_model_path)

# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda()

query = '描述这张图片'  # "Describe this image"

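# convert('RGB') normalizes palette/alpha images to the 3-channel input the
# vision encoder expects.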
img_file_path = 'images/cat.png'
image = Image.open(img_file_path).convert('RGB')
print(f'--- type(image): {type(image)}')

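# GLM-4V's chat template takes the PIL image via the `image` key of the
# message dict; with return_dict=True the tokenizer also returns the
# preprocessed image tensor alongside input_ids.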
inputs = tokenizer.apply_chat_template([{"role": "user", "image": image, "content": query}],
                                       add_generation_prompt=True, tokenize=True, return_tensors="pt",
                                       return_dict=True)  # chat mode
inputs = inputs.to(device)


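# do_sample=True with top_k=1 is effectively greedy decoding; max_length
# caps prompt plus generated tokens combined.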
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}

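# tt.__init__() re-initializes the shared bdtime timer so elapsed time starts at zero.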
tt.__init__()
print('--- start generate')

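# Benchmark: run generation `run_times` times and report per-run latency.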
run_times = 5
with torch.no_grad():
    with with_timer('benchmark', tt) as wt:
        # with_timer usage example:
        # for i in range(10):
        #     tt.sleep(0.3)
        #     if i % 5 == 0:
        #         wt.show(f"loss at step {i}: {i * 2 / 5}")
        for i in range(run_times):
            outputs = model.generate(**inputs, **gen_kwargs)
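            # Drop the prompt tokens; keep only the newly generated continuation.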
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            if i == 0:
                print("*** ouput:", tokenizer.decode(outputs[0]))
            wt.show(f"第{i}次", reset_cost=True)

total_cost = tt.now()
print(f'--- total_cost_time: {total_cost}, mean_cost_time: {total_cost / run_times:.3f}')
